{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Solving the Frozen Lake Problem Using Policy Iteration"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<font size=3.5> Imagine that a frozen lake lies between your home and your office, and you have to walk across it\n",
    "to reach the office. The lake has holes in it, however, so you must\n",
    "walk carefully to avoid falling into one.\n",
    "In the figure below:\n",
    "\n",
    "1. S is the starting position (home)\n",
    "2. F is the frozen surface, where you can walk safely\n",
    "3. H is a hole, which you must avoid\n",
    "4. G is the goal (office)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![title](images/B09792_03_50.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<font size = 3.5> Now let an agent, instead of you, find the way to the office.\n",
    "The agent's goal is to find the optimal path from S to G without falling into H.\n",
    "How can the agent achieve this? In this environment, the agent receives a reward of +1 when it\n",
    "reaches the goal G and 0 for every other step, including falling into a hole, which ends the episode.\n",
    "This reward signal lets the agent determine which actions are good. The agent tries to find the optimal\n",
    "policy, that is, the action to take in each state that maximizes its expected reward. By\n",
    "maximizing the reward, the agent learns to avoid the holes and reach the\n",
    "destination."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<font size=3.5> First, we import the necessary libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<font size=3.5> Next, we initialize the Gym environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2018-06-01 12:00:11,336] Making new env: FrozenLake-v0\n"
     ]
    }
   ],
   "source": [
    "env = gym.make('FrozenLake-v0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<font size=3.5> Let us see what the environment looks like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[41mS\u001b[0mFFF\n",
      "FHFH\n",
      "FFFH\n",
      "HFFG\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<ipykernel.iostream.OutStream at 0x7fae425afbe0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.render()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<font size=3.5> Now, we define a function that computes the value function of a given policy (policy evaluation)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def compute_value_function(policy, gamma=1.0):\n",
    "    \n",
    "    # initialize value table with zeros\n",
    "    value_table = np.zeros(env.nS)\n",
    "    \n",
    "    # set the threshold\n",
    "    threshold = 1e-10\n",
    "    \n",
    "    while True:\n",
    "        \n",
    "        # keep a copy of the values from the previous sweep in updated_value_table\n",
    "        updated_value_table = np.copy(value_table)\n",
    "\n",
    "        # for each state, take the action prescribed by the policy and update the state's value\n",
    "        for state in range(env.nS):\n",
    "            # policy entries are stored as floats, so cast to int before indexing env.P\n",
    "            action = int(policy[state])\n",
    "            \n",
    "            # expected immediate reward plus discounted value of the next state,\n",
    "            # summed over all possible transitions for this state and action\n",
    "            value_table[state] = sum([trans_prob * (reward_prob + gamma * updated_value_table[next_state]) \n",
    "                        for trans_prob, next_state, reward_prob, _ in env.P[state][action]])\n",
    "            \n",
    "        if (np.sum((np.fabs(updated_value_table - value_table))) <= threshold):\n",
    "            break\n",
    "            \n",
    "    return value_table\n"
   ]
  },
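  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<font size=3.5> As a rough sketch, each sweep of the loop in compute_value_function applies the following update to every state $s$.\n",
    "The symbols $\\pi$, $P$, $R$ and $V$ are notation assumed here for illustration and do not appear in the code:\n",
    "$\\pi(s)$ stands for policy[state], $P(s' \\mid s, a)$ for trans_prob, $R(s, a, s')$ for reward_prob, $\\gamma$ for gamma, and $V$ for the value table.\n",
    "\n",
    "$$V_{k+1}(s) = \\sum_{s'} P\\left(s' \\mid s, \\pi(s)\\right) \\left[ R\\left(s, \\pi(s), s'\\right) + \\gamma \\, V_k(s') \\right]$$\n",
    "\n",
    "The sweeps stop once the total change is small enough, which matches the threshold check in the code:\n",
    "\n",
    "$$\\sum_{s} \\left| V_{k+1}(s) - V_k(s) \\right| \\le 10^{-10}$$"
   ]
  },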
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<font size=3.5> Now, we define a function called extract_policy that extracts a policy from a value function.\n",
    "That is, we calculate the Q value of every action in each state using the given value function and pick\n",
    "the action with the highest Q value in each state. Once the value function is optimal, the extracted policy is the optimal policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    "def extract_policy(value_table, gamma = 1.0):\n",
    " \n",
    "    # Initialize the policy with zeros\n",
    "    policy = np.zeros(env.observation_space.n) \n",
    "    \n",
    "    \n",
    "    for state in range(env.observation_space.n):\n",
    "        \n",
    "        # initialize the Q table for a state\n",
    "        Q_table = np.zeros(env.action_space.n)\n",
    "        \n",
    "        # compute the Q value for all actions in the state\n",
    "        for action in range(env.action_space.n):\n",
    "            for next_sr in env.P[state][action]: \n",
    "                trans_prob, next_state, reward_prob, _ = next_sr \n",
    "                Q_table[action] += (trans_prob * (reward_prob + gamma * value_table[next_state]))\n",
    "        \n",
    "        # Select the action which has maximum Q value as an optimal action of the state\n",
    "        policy[state] = np.argmax(Q_table)\n",
    "    \n",
    "    return policy"
   ]
  },
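  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<font size=3.5> The computation in extract_policy above can be summarized, as a sketch, by the two expressions below.\n",
    "The notation is assumed for illustration: $Q(s, a)$ stands for Q_table[action] in state $s$, $P(s' \\mid s, a)$ for trans_prob, $R(s, a, s')$ for reward_prob, $V$ for value_table, and $\\pi'$ for the returned policy.\n",
    "\n",
    "$$Q(s, a) = \\sum_{s'} P\\left(s' \\mid s, a\\right) \\left[ R\\left(s, a, s'\\right) + \\gamma \\, V(s') \\right]$$\n",
    "\n",
    "$$\\pi'(s) = \\arg\\max_{a} Q(s, a)$$\n",
    "\n",
    "When several actions share the highest Q value, np.argmax returns the first of them, so ties are broken toward the lowest action index."
   ]
  },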
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
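    "<font size=3.5> Now, we define the function that performs policy iteration. It alternates between evaluating the current policy and extracting a new policy from the resulting value function until the policy stops changing."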
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def policy_iteration(env,gamma = 1.0):\n",
    "    \n",
    "    # Initialize policy with zeros\n",
    "    old_policy = np.zeros(env.observation_space.n)   \n",
    "    no_of_iterations = 200000\n",
    "    \n",
    "    for i in range(no_of_iterations):\n",
    "        \n",
    "        # compute the value function\n",
    "        new_value_function = compute_value_function(old_policy, gamma)\n",
    "        \n",
    "        # Extract new policy from the computed value function\n",
    "        new_policy = extract_policy(new_value_function, gamma)\n",
    "   \n",
    "        # Check for convergence: if the new policy is identical to the old one, we have\n",
    "        # found the optimal policy and stop iterating; otherwise we replace old_policy\n",
    "        # with new_policy and continue\n",
    "\n",
    "        if (np.all(old_policy == new_policy)):\n",
    "            print ('Policy-Iteration converged at step %d.' %(i+1))\n",
    "            break\n",
    "        old_policy = new_policy\n",
    "        \n",
    "    return new_policy\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Policy-Iteration converged at step 7.\n",
      "[0. 3. 3. 3. 0. 0. 0. 0. 3. 1. 0. 0. 0. 2. 1. 0.]\n"
     ]
    }
   ],
   "source": [
    "print (policy_iteration(env))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:universe]",
   "language": "python",
   "name": "conda-env-universe-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
